# -*- coding: utf-8 -*-
"""User Agent class"""
from typing import Union, Type
from typing import Optional

from pydantic import BaseModel

from ._agent import AgentBase
from ._user_input import TerminalUserInput, UserInputBase
from ..message import Msg


class UserAgent(AgentBase):
    """An agent that stands in for a human user.

    `reply` obtains the user's input through a pluggable input method (an
    object whose class inherits from `UserInputBase`; `TerminalUserInput` by
    default) and wraps it in a `Msg` with role ``"user"``.
    """

    # Input method shared by every UserAgent unless overridden. Replace it for
    # a single instance with `override_input_method`, or for the whole class
    # with `override_class_input_method`.
    _input_method: UserInputBase = TerminalUserInput()

    def __init__(
        self,
        name: str = "User",
        input_hint: str = "User Input: ",
        require_url: bool = False,
    ) -> None:
        """Initialize a UserAgent object.

        Args:
            name (`str`, defaults to `"User"`):
                The name of the agent.
            input_hint (`str`, defaults to `"User Input: "`):
                The hint shown when asking for input.
            require_url (`bool`, defaults to `False`):
                Whether the agent requires the user to input a URL. The URL
                can lead to a website, a file, or a directory.
        """
        super().__init__(name=name)

        self.name = name
        # NOTE(review): `input_hint` and `require_url` are stored here but are
        # not read by `reply` in this file (the input method is not passed
        # either of them). Confirm whether they are consumed elsewhere.
        self.input_hint = input_hint
        self.require_url = require_url

    def override_input_method(
        self,
        input_method: UserInputBase,
    ) -> None:
        """Override the input method of this UserAgent instance only.

        The assignment creates an instance attribute that shadows the
        class-level default, so other UserAgent instances are unaffected.

        Args:
            input_method (`UserInputBase`):
                The callable input method, which should be an object of a
                class that inherits from `UserInputBase`.
        """
        self._input_method = input_method

    @classmethod
    def override_class_input_method(
        cls,
        input_method: UserInputBase,
    ) -> None:
        """Override the input method for the whole UserAgent class.

        Every instance that has not set its own input method through
        `override_input_method` will use the new one.

        Args:
            input_method (`UserInputBase`):
                The callable input method, which should be an object of a
                class that inherits from `UserInputBase`.
        """
        cls._input_method = input_method

    def reply(
        self,
        x: Optional[Union[Msg, list[Msg]]] = None,
        structured_output: Optional[Type[BaseModel]] = None,
    ) -> Msg:
        """Receive input message(s) and generate a reply message from the user.

        The input source can be changed with `override_input_method` (per
        instance) or `override_class_input_method` (per class). For example,
        to receive user input from your own source, create a class that
        inherits from `UserInputBase` and install an instance of it with one
        of those methods.

        Args:
            x (`Optional[Union[Msg, list[Msg]]]`, defaults to `None`):
                The input message(s) to the agent, which can be omitted if
                the agent doesn't need any input.
            structured_output (`Optional[Type[BaseModel]]`, defaults to \
`None`):
                A child class of `pydantic.BaseModel` that defines the
                structured output format. If provided, the user will be
                prompted to fill in the required fields. The output will be
                placed in the `metadata` field of the output message.

        Returns:
            `Msg`: The output message generated by the agent.
        """
        # Compare against None explicitly: a memory object that defines
        # `__len__` would be falsy while empty, and a plain truthiness check
        # would then silently skip recording messages.
        if self.memory is not None:
            self.memory.add(x)

        # Get the input from the configured input method.
        input_data = self._input_method(
            agent_id=self.agent_id,
            agent_name=self.name,
            structured_schema=structured_output,
        )

        blocks_input = input_data.blocks_input
        if (
            blocks_input
            and len(blocks_input) == 1
            and blocks_input[0].get("type") == "text"
        ):
            # Collapse to a plain string when the input is a single text block.
            blocks_input = blocks_input[0].get("text")

        msg = Msg(
            self.name,
            content=blocks_input,
            role="user",
            metadata=input_data.structured_input,
        )

        self.speak(msg)

        if self.memory is not None:
            self.memory.add(msg)

        return msg
